{
 "cells": [
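  {
   "cell_type": "markdown",
   "id": "bc0db63f-8d18-4929-9535-edf58ae15e62",
   "metadata": {},
   "source": [
    "## LLaMA 2 Instruction Fine-Tuning (Alpaca-Style on the Dolly-15K Dataset)\n",
    "\n",
    "Key training elements of this example:\n",
    "- Use the Dolly-15K dataset to build training data in the Alpaca instruction style\n",
    "- Load the `LLaMA 2-7B` model with 4-bit (NF4) quantization\n",
    "- Train the model with QLoRA using `bf16` mixed precision\n",
    "- Use `SFTTrainer` from `HuggingFace TRL` for supervised instruction fine-tuning\n",
    "- Use Flash Attention to speed up training (requires hardware support)"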
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e97de889-4d1d-4ee3-ad57-289e2e5919d8",
   "metadata": {},
   "source": [
    "### Download the databricks-dolly-15k Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ef235f0d-df0d-4be9-a034-af6684588dbe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/peft/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading readme: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8.20k/8.20k [00:00<00:00, 12.6MB/s]\n",
      "Downloading data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13.1M/13.1M [00:02<00:00, 5.29MB/s]\n",
      "Generating train split: 15011 examples [00:00, 428978.17 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from random import randrange\n",
    " \n",
    "# Load the dataset from the Hub\n",
    "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4c5f398a-6dd4-4f71-a474-1b647103471a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['instruction', 'context', 'response', 'category'],\n",
       "    num_rows: 15011\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total number of examples in the dataset: 15011\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "61788bc1-6c19-4301-9089-14a59c23b869",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'instruction': 'When was The White Mandigos band formed?', 'context': 'The White Mandingos are a rock supergroup from Woodstock, New York consisting of rapper Murs, former Rolling Stone journalist and MTV / VH1 producer Sacha Jenkins and Bad Brains\\' bassist Darryl Jenifer.\\n\\nBiography\\nThe band was formed in late 2012 when Jenkins met up at Jenifer\\'s house in Woodstock to discover if there was anything in common between their respective favourite music genres. They considered their initial collaborations unimpressive, so Jenkins suggested collaborating with Murs, who provided lyrics.\\n\\nTheir first album, The Ghetto Is Tryna Kill Me was released in June 2013, and followed with a short tour of the eastern United States, including gigs in New York\\'s New Museum, Boston and Washington DC. The album is a concept album around Tyrone White, a young black man from a New York City housing project, who subsequently obtains a recording contract and gets a white girlfriend. Jenifer and Jenkins have described the album Tommy by The Who as an important influence. Reviewing the album, Baltimore City Paper \\'s Baynard Woods thought the group \"actually manage to do service to punk and hip hop\" and praised the band\\'s sense of humour, particularly the music video for their first single, \"My First White Girl\". Washington City Paper\\'s Marcus J Moore described the video for the group\\'s \"Warn A Brotha\" as \"a cool ode to skateboarding\".', 'response': \"The White Mandingos was formed in late 2012, and their first album 'The Guetto is Tryna Kill Me' was released in June 2013.\", 'category': 'closed_qa'}\n"
     ]
    }
   ],
   "source": [
    "# Print a randomly selected example\n",
    "print(dataset[randrange(len(dataset))])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71040ca1-92d0-4e1e-af66-49ef2a52c3d7",
   "metadata": {},
   "source": [
    "### Format the Instruction Data in Alpaca Style\n",
    "\n",
    "`Alpaca-style` format: https://github.com/tatsu-lab/stanford_alpaca#data-release"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9d2889f7-ff56-4ee5-a852-1794c0a296ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_instruction(sample_data):\n",
    "    \"\"\"\n",
    "    Formats the given data into a structured instruction format.\n",
    "\n",
    "    Parameters:\n",
    "    sample_data (dict): A dictionary containing 'response' and 'instruction' keys.\n",
    "\n",
    "    Returns:\n",
    "    str: A formatted string containing the instruction, input, and response.\n",
    "\n",
    "    Raises:\n",
    "    KeyError: If 'response' or 'instruction' is missing from sample_data.\n",
    "    \"\"\"\n",
    "    # Fail loudly if a required key is missing, so that a malformed sample\n",
    "    # never ends up in the training data as an error string\n",
    "    if 'response' not in sample_data or 'instruction' not in sample_data:\n",
    "        raise KeyError(\"'response' or 'instruction' key missing in the input data.\")\n",
    "\n",
    "    return f\"\"\"### Instruction:\n",
    "Use the Input below to create an instruction, which could have been used to generate the input using an LLM. \n",
    " \n",
    "### Input:\n",
    "{sample_data['response']}\n",
    " \n",
    "### Response:\n",
    "{sample_data['instruction']}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "26167b60-f339-487a-a2dc-0cf49fac8932",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "### Instruction:\n",
      "Use the Input below to create an instruction, which could have been used to generate the input using an LLM. \n",
      " \n",
      "### Input:\n",
      "Phnom Penh was abandoned after being besieged by the Khmer Rouge on April 12th 1975\n",
      " \n",
      "### Response:\n",
      "What US embassy was abandoned on April 12th, 1975\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Pick a random example and print it in Alpaca format\n",
    "print(format_instruction(dataset[randrange(len(dataset))]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eea57c8-ffac-422c-9fa1-b4b98dd3f917",
   "metadata": {},
   "source": [
    "### Speed Up Training with Flash Attention\n",
    "\n",
    "Check whether your GPU supports `flash-attn` acceleration (the check below requires CUDA compute capability 8.0 or higher):\n",
    "\n",
    "```shell\n",
    "$ python -c \"import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'\"\n",
    "\n",
    "Traceback (most recent call last):\n",
    "  File \"<string>\", line 1, in <module>\n",
    "AssertionError: Hardware not supported for Flash Attention\n",
    "```\n",
    "**Result: the NVIDIA T4 used for this demo does not support Flash Attention.**\n",
    "\n",
    "#### Install the flash-attn package (requires GPU hardware support)\n",
    "\n",
    "```shell\n",
    "$ MAX_JOBS=4 pip install flash-attn --no-build-isolation\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2e9f1714-9142-4506-b8e0-6f4553b6d8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f2561b57-b437-404e-851b-5fb15adfb18a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Collecting flash-attn\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/72/94/06f618bb338ec7203b48ac542e73087362b7750f9c568b13d213a3f181bb/flash_attn-2.5.8.tar.gz (2.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.5/2.5 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25hRequirement already satisfied: torch in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from flash-attn) (2.2.2)\n",
      "Collecting einops (from flash-attn)\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/44/5a/f0b9ad6c0a9017e62d4735daaeb11ba3b6c009d69a26141b258cd37b5588/einops-0.8.0-py3-none-any.whl (43 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: packaging in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from flash-attn) (24.0)\n",
      "Collecting ninja (from flash-attn)\n",
      "  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/6d/92/8d7aebd4430ab5ff65df2bfee6d5745f95c004284db2d8ca76dcbfd9de47/ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)\n",
      "Requirement already satisfied: filelock in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (3.13.3)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (4.10.0)\n",
      "Requirement already satisfied: sympy in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (1.12)\n",
      "Requirement already satisfied: networkx in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (3.1.3)\n",
      "Requirement already satisfied: fsspec in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (2023.10.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from torch->flash-attn) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->flash-attn) (12.4.127)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.5)\n",
      "Requirement already satisfied: mpmath>=0.19 in /root/miniconda3/envs/peft/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0)\n",
      "Building wheels for collected packages: flash-attn\n",
      "  Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n",
      "\u001b[?25h  Created wheel for flash-attn: filename=flash_attn-2.5.8-cp310-cp310-linux_x86_64.whl size=120853537 sha256=53979129f883680327bf5d13027cd014e2d054f4fb5b8856916686ae315e57d6\n",
      "  Stored in directory: /root/.cache/pip/wheels/12/f9/45/14b376d8a258ceea3760bb70fcb3e005fc2fdeb537d422b22f\n",
      "Successfully built flash-attn\n",
      "Installing collected packages: ninja, einops, flash-attn\n",
      "Successfully installed einops-0.8.0 flash-attn-2.5.8 ninja-1.11.1.1\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!MAX_JOBS=4 pip install flash-attn --no-build-isolation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45a715dc-3ecd-41bb-9b2e-88c1db991e92",
   "metadata": {},
   "source": [
    "### Load the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "677fa03d-439e-4aa1-81f0-e1015d71c189",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using flash attention\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [24:24<00:00, 732.07s/it]\n",
      "Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.49it/s]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "\n",
    "# Stays False unless the GPU check below passes (flash-attn must also be installed)\n",
    "use_flash_attention = False\n",
    " \n",
    "# Patch the LLaMA attention with flash-attn when the GPU has compute capability >= 8\n",
    "if torch.cuda.get_device_capability()[0] >= 8:\n",
    "    from utils.llama_patch import replace_attn_with_flash_attn\n",
    "    print(\"Using flash attention\")\n",
    "    replace_attn_with_flash_attn()\n",
    "    use_flash_attention = True\n",
    " \n",
    "# Get the LLaMA 2-7B model weights\n",
    "# These weights do not require approval from Meta AI\n",
    "model_id = \"NousResearch/Llama-2-7b-hf\" \n",
    "# Once Meta AI has approved your access request, you can use this model ID instead\n",
    "# model_id = \"meta-llama/Llama-2-7b-hf\" \n",
    "\n",
    "# Load the quantized model with bitsandbytes (BnB)\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    " \n",
    "# Load the model and tokenizer\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=False, device_map=\"auto\")\n",
    "model.config.pretraining_tp = 1 \n",
    " \n",
    "# Verify that the model is using flash attention by comparing docstrings\n",
    "if use_flash_attention:\n",
    "    from utils.llama_patch import forward    \n",
    "    assert model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, \"Model is not using flash attention\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"right\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "782b1b42-97f2-4378-99d3-d582fae48e03",
   "metadata": {},
   "source": [
    "### Load the PEFT Model with the QLoRA Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "50690552-1807-4e53-892a-1aad10cadaca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n",
    " \n",
    "# QLoRA configuration\n",
    "peft_config = LoraConfig(\n",
    "        lora_alpha=16,\n",
    "        lora_dropout=0.1,\n",
    "        r=16,\n",
    "        bias=\"none\",\n",
    "        task_type=\"CAUSAL_LM\", \n",
    ")\n",
    " \n",
    " \n",
    "# Prepare the quantized model for training and wrap it with the QLoRA configuration\n",
    "model = prepare_model_for_kbit_training(model)\n",
    "qlora_model = get_peft_model(model, peft_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "98490c6a-fa9d-460b-8362-33444efd2e8a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 8,388,608 || all params: 6,746,804,224 || trainable%: 0.12433454005023165\n"
     ]
    }
   ],
   "source": [
    "qlora_model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2c0f7b6-30a3-4e10-8a4a-9840e465e590",
   "metadata": {},
   "source": [
    "### Training Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "eb427d32-575f-4af7-af71-bbcfc0b3b730",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "\n",
    "# Demo training flag (set to False for a real training run)\n",
    "demo_train = False\n",
    "output_dir = f\"models/llama-7-int4-dolly-{timestamp}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1f57027d-63cc-4199-9ae7-d04fa9db3f2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    " \n",
    "args = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    num_train_epochs=1 if demo_train else 3,\n",
    "    max_steps=100, # a positive max_steps overrides num_train_epochs\n",
    "    per_device_train_batch_size=3, # largest batch size that fits in the 16 GB of an NVIDIA T4\n",
    "    gradient_accumulation_steps=1 if demo_train else 4,\n",
    "    gradient_checkpointing=True,\n",
    "    optim=\"paged_adamw_32bit\",\n",
    "    logging_steps=10,\n",
    "    save_strategy=\"steps\" if demo_train else \"epoch\",\n",
    "    save_steps=10,\n",
    "    learning_rate=2e-4,\n",
    "    bf16=True,\n",
    "    max_grad_norm=0.3,\n",
    "    warmup_ratio=0.03,\n",
    "    lr_scheduler_type=\"constant\",\n",
    "    logging_dir='./logs/'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e28e91-6daa-44c9-8e1e-59abddd57697",
   "metadata": {},
   "source": [
    "### Instantiate SFTTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "53c29ebe-f962-4b2f-9905-743f63024cf0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating train split: 1158 examples [00:00, 1633.53 examples/s]\n",
      "/root/miniconda3/envs/peft/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:341: UserWarning: You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from trl import SFTTrainer\n",
    " \n",
    "# Maximum sequence length (the processed training set contains 1158 examples)\n",
    "max_seq_length = 2048 \n",
    " \n",
    "trainer = SFTTrainer(\n",
    "    model=qlora_model,\n",
    "    train_dataset=dataset,\n",
    "    peft_config=peft_config,\n",
    "    max_seq_length=max_seq_length,\n",
    "    tokenizer=tokenizer,\n",
    "    packing=True,\n",
    "    formatting_func=format_instruction, \n",
    "    args=args,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c716498f-93c7-4071-858e-c53d270eeb2d",
   "metadata": {},
   "source": [
    "### Train the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "43d07ef2-d07c-41c1-8eaa-2472f289055c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/root/miniconda3/envs/peft/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "FlashAttention only support fp16 and bf16 data type",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:360\u001b[0m, in \u001b[0;36mSFTTrainer.train\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    357\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneftune_noise_alpha \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer_supports_neftune:\n\u001b[1;32m    358\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trl_activate_neftune(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel)\n\u001b[0;32m--> 360\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    362\u001b[0m \u001b[38;5;66;03m# After training we make sure to retrieve back the original forward pass method\u001b[39;00m\n\u001b[1;32m    363\u001b[0m \u001b[38;5;66;03m# for the embedding layer by removing the forward post hook.\u001b[39;00m\n\u001b[1;32m    364\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneftune_noise_alpha \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_trainer_supports_neftune:\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/trainer.py:1645\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1640\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m   1642\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m   1643\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m   1644\u001b[0m )\n\u001b[0;32m-> 1645\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1646\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1647\u001b[0m \u001b[43m    \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1648\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1649\u001b[0m \u001b[43m    \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1650\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/trainer.py:1938\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   1935\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m   1937\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1938\u001b[0m     tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1940\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m   1941\u001b[0m     args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m   1942\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m   1943\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m   1944\u001b[0m ):\n\u001b[1;32m   1945\u001b[0m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m   1946\u001b[0m     tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/trainer.py:2759\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m   2756\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m   2758\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2759\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2761\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m   2762\u001b[0m     loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean()  \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/trainer.py:2784\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m   2782\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2783\u001b[0m     labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 2784\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2785\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m   2786\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m   2787\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:687\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    686\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 687\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/accelerate/utils/operations.py:675\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    674\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 675\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m convert_to_fp32(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m     14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_autocast\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     15\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 16\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/peft/peft_model.py:1073\u001b[0m, in \u001b[0;36mPeftModelForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)\u001b[0m\n\u001b[1;32m   1062\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mforward in MPTForCausalLM does not support inputs_embeds\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m   1063\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_model(\n\u001b[1;32m   1064\u001b[0m             input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m   1065\u001b[0m             attention_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1070\u001b[0m             \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m   1071\u001b[0m         )\n\u001b[0;32m-> 1073\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1074\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1075\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1076\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1077\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1078\u001b[0m 
\u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1079\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1080\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1081\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1082\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1084\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m _get_batch_size(input_ids, inputs_embeds)\n\u001b[1;32m   1085\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1086\u001b[0m     \u001b[38;5;66;03m# concat prompt attention mask\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:103\u001b[0m, in \u001b[0;36mBaseTuner.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    102\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any):\n\u001b[0;32m--> 103\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:688\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    685\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m    687\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 688\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    689\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    690\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    691\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    692\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    693\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    694\u001b[0m \u001b[43m    
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    695\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    696\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    697\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    698\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    700\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    701\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:570\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    566\u001b[0m             \u001b[38;5;28;01mreturn\u001b[39;00m module(\u001b[38;5;241m*\u001b[39minputs, output_attentions, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m    568\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m custom_forward\n\u001b[0;32m--> 570\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mutils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcheckpoint\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcheckpoint\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    571\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcreate_custom_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdecoder_layer\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    572\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    573\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    574\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    575\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    576\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    577\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    578\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m decoder_layer(\n\u001b[1;32m    579\u001b[0m         hidden_states,\n\u001b[1;32m    580\u001b[0m         attention_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    
584\u001b[0m         use_cache\u001b[38;5;241m=\u001b[39muse_cache,\n\u001b[1;32m    585\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/_compile.py:24\u001b[0m, in \u001b[0;36m_disable_dynamo.<locals>.inner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m     22\u001b[0m     \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_dynamo\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dynamo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdisable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrecursive\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489\u001b[0m, in \u001b[0;36m_TorchDynamoContext.__call__.<locals>._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    487\u001b[0m     dynamo_config_ctx\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__enter__\u001b[39m()\n\u001b[1;32m    488\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 489\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    490\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    491\u001b[0m     set_eval_frame(prior)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:17\u001b[0m, in \u001b[0;36mwrap_inline.<locals>.inner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 17\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/utils/checkpoint.py:482\u001b[0m, in \u001b[0;36mcheckpoint\u001b[0;34m(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)\u001b[0m\n\u001b[1;32m    477\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m context_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m noop_context_fn \u001b[38;5;129;01mor\u001b[39;00m debug \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m:\n\u001b[1;32m    478\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    479\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing `context_fn` or `debug` is only supported when \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    480\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muse_reentrant=False.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    481\u001b[0m         )\n\u001b[0;32m--> 482\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mCheckpointFunction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpreserve\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    483\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    484\u001b[0m     gen \u001b[38;5;241m=\u001b[39m _checkpoint_without_reentrant_generator(\n\u001b[1;32m    485\u001b[0m         function, preserve, context_fn, determinism_check, debug, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    486\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/autograd/function.py:553\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m    551\u001b[0m     \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m    552\u001b[0m     args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m    555\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_setup_ctx_defined:\n\u001b[1;32m    556\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    557\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    558\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    559\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstaticmethod. 
For more details, please see \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    560\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/master/notes/extending.func.html\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    561\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/utils/checkpoint.py:261\u001b[0m, in \u001b[0;36mCheckpointFunction.forward\u001b[0;34m(ctx, run_function, preserve_rng_state, *args)\u001b[0m\n\u001b[1;32m    258\u001b[0m ctx\u001b[38;5;241m.\u001b[39msave_for_backward(\u001b[38;5;241m*\u001b[39mtensor_inputs)\n\u001b[1;32m    260\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 261\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mrun_function\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:566\u001b[0m, in \u001b[0;36mLlamaModel.forward.<locals>.create_custom_forward.<locals>.custom_forward\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m    564\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcustom_forward\u001b[39m(\u001b[38;5;241m*\u001b[39minputs):\n\u001b[1;32m    565\u001b[0m     \u001b[38;5;66;03m# None for past_key_value\u001b[39;00m\n\u001b[0;32m--> 566\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:292\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)\u001b[0m\n\u001b[1;32m    289\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m    291\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 292\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    293\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    294\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    295\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    296\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    297\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    298\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    299\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    300\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m 
hidden_states\n\u001b[1;32m    302\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1509\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1518\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1519\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1523\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[38;5;241m=\u001b[39m module\u001b[38;5;241m.\u001b[39m_old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_old_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/gaoliang/study/LLM-quickstart/llama/utils/llama_patch.py:87\u001b[0m, in \u001b[0;36mforward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)\u001b[0m\n\u001b[1;32m     85\u001b[0m     x_unpad, indices, cu_q_lens, max_s \u001b[38;5;241m=\u001b[39m unpad_input(x, key_padding_mask)\n\u001b[1;32m     86\u001b[0m     x_unpad \u001b[38;5;241m=\u001b[39m rearrange(x_unpad, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnnz (three h d) -> nnz three h d\u001b[39m\u001b[38;5;124m\"\u001b[39m, three\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, h\u001b[38;5;241m=\u001b[39mnheads)\n\u001b[0;32m---> 87\u001b[0m     output_unpad \u001b[38;5;241m=\u001b[39m \u001b[43mflash_attn_varlen_qkvpacked_func\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     88\u001b[0m \u001b[43m        \u001b[49m\u001b[43mx_unpad\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcu_q_lens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_s\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msoftmax_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcausal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m     89\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     90\u001b[0m     output \u001b[38;5;241m=\u001b[39m rearrange(\n\u001b[1;32m     91\u001b[0m         pad_input(rearrange(output_unpad, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnnz h d -> nnz (h d)\u001b[39m\u001b[38;5;124m\"\u001b[39m), indices, bsz, q_len),\n\u001b[1;32m     92\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb s (h d) -> b s h d\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     93\u001b[0m         h\u001b[38;5;241m=\u001b[39mnheads,\n\u001b[1;32m     
94\u001b[0m     )\n\u001b[1;32m     95\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(rearrange(output, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb s h d -> b s (h d)\u001b[39m\u001b[38;5;124m\"\u001b[39m)), \u001b[38;5;28;01mNone\u001b[39;00m, past_key_value\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:893\u001b[0m, in \u001b[0;36mflash_attn_varlen_qkvpacked_func\u001b[0;34m(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs)\u001b[0m\n\u001b[1;32m    845\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mflash_attn_varlen_qkvpacked_func\u001b[39m(\n\u001b[1;32m    846\u001b[0m     qkv,\n\u001b[1;32m    847\u001b[0m     cu_seqlens,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    855\u001b[0m     return_attn_probs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    856\u001b[0m ):\n\u001b[1;32m    857\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"dropout_p should be set to 0.0 during evaluation\u001b[39;00m\n\u001b[1;32m    858\u001b[0m \u001b[38;5;124;03m    If Q, K, V are already stacked into 1 tensor, this function will be faster than\u001b[39;00m\n\u001b[1;32m    859\u001b[0m \u001b[38;5;124;03m    calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    891\u001b[0m \u001b[38;5;124;03m            pattern (negative means that location was dropped, nonnegative means it was kept).\u001b[39;00m\n\u001b[1;32m    892\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 893\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mFlashAttnVarlenQKVPackedFunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    894\u001b[0m \u001b[43m        \u001b[49m\u001b[43mqkv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    895\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcu_seqlens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    896\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_seqlen\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    897\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mdropout_p\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    898\u001b[0m \u001b[43m        \u001b[49m\u001b[43msoftmax_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    899\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcausal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    900\u001b[0m \u001b[43m        \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    901\u001b[0m \u001b[43m        \u001b[49m\u001b[43malibi_slopes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    902\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    903\u001b[0m \u001b[43m        \u001b[49m\u001b[43mreturn_attn_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    904\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/torch/autograd/function.py:553\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m    551\u001b[0m     \u001b[38;5;66;03m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m    552\u001b[0m     args \u001b[38;5;241m=\u001b[39m _functorch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 553\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m    555\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_setup_ctx_defined:\n\u001b[1;32m    556\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    557\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    558\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    559\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstaticmethod. 
For more details, please see \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    560\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://pytorch.org/docs/master/notes/extending.func.html\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    561\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:290\u001b[0m, in \u001b[0;36mFlashAttnVarlenQKVPackedFunc.forward\u001b[0;34m(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax)\u001b[0m\n\u001b[1;32m    288\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m softmax_scale \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    289\u001b[0m     softmax_scale \u001b[38;5;241m=\u001b[39m qkv\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m (\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m0.5\u001b[39m)\n\u001b[0;32m--> 290\u001b[0m out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state \u001b[38;5;241m=\u001b[39m \u001b[43m_flash_attn_varlen_forward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    291\u001b[0m \u001b[43m    \u001b[49m\u001b[43mqkv\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    292\u001b[0m \u001b[43m    \u001b[49m\u001b[43mqkv\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    293\u001b[0m \u001b[43m    \u001b[49m\u001b[43mqkv\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    294\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcu_seqlens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    295\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcu_seqlens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    296\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_seqlen\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    297\u001b[0m 
\u001b[43m    \u001b[49m\u001b[43mmax_seqlen\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    298\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdropout_p\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    299\u001b[0m \u001b[43m    \u001b[49m\u001b[43msoftmax_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    300\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcausal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    301\u001b[0m \u001b[43m    \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwindow_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    302\u001b[0m \u001b[43m    \u001b[49m\u001b[43malibi_slopes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43malibi_slopes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    303\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_softmax\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_softmax\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mand\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdropout_p\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    304\u001b[0m \u001b[43m    \u001b[49m\u001b[43mblock_table\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    305\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    306\u001b[0m ctx\u001b[38;5;241m.\u001b[39msave_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)\n\u001b[1;32m    307\u001b[0m ctx\u001b[38;5;241m.\u001b[39mdropout_p \u001b[38;5;241m=\u001b[39m dropout_p\n",
      "File \u001b[0;32m~/miniconda3/envs/peft/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:86\u001b[0m, in \u001b[0;36m_flash_attn_varlen_forward\u001b[0;34m(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, block_table)\u001b[0m\n\u001b[1;32m     84\u001b[0m maybe_contiguous \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m x: x\u001b[38;5;241m.\u001b[39mcontiguous() \u001b[38;5;28;01mif\u001b[39;00m x\u001b[38;5;241m.\u001b[39mstride(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m x\n\u001b[1;32m     85\u001b[0m q, k, v \u001b[38;5;241m=\u001b[39m [maybe_contiguous(x) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m (q, k, v)]\n\u001b[0;32m---> 86\u001b[0m out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state \u001b[38;5;241m=\u001b[39m \u001b[43mflash_attn_cuda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvarlen_fwd\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     87\u001b[0m \u001b[43m    \u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     88\u001b[0m \u001b[43m    \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     89\u001b[0m \u001b[43m    \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     90\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     91\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcu_seqlens_q\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     92\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcu_seqlens_k\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     93\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     94\u001b[0m \u001b[43m    \u001b[49m\u001b[43mblock_table\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     
95\u001b[0m \u001b[43m    \u001b[49m\u001b[43malibi_slopes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     96\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_seqlen_q\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     97\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_seqlen_k\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     98\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdropout_p\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     99\u001b[0m \u001b[43m    \u001b[49m\u001b[43msoftmax_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    100\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    101\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcausal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    102\u001b[0m \u001b[43m    \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    103\u001b[0m \u001b[43m    \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    104\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_softmax\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    105\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    106\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    107\u001b[0m \u001b[38;5;66;03m# if out.isnan().any() or softmax_lse.isnan().any():\u001b[39;00m\n\u001b[1;32m    108\u001b[0m \u001b[38;5;66;03m#     breakpoint()\u001b[39;00m\n\u001b[1;32m    109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state\n",
      "\u001b[0;31mRuntimeError\u001b[0m: FlashAttention only support fp16 and bf16 data type"
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd5eeb5e-2082-4ead-b4a2-ebedb06333f9",
   "metadata": {},
   "source": [
    "### Save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e874176b-1b1a-4264-b24e-5308c773e81d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d39256f7-6b10-4f6d-8a99-c96a512604ec",
   "metadata": {},
   "source": [
    "### Model inference (testing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e38d38d-a675-40a1-aa0d-b78e90865162",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
